#%%
import torch
import numpy as np
from pathlib import Path
import sys
sys.path.append("/home/liucong/codes/repos-3D/fast3r")
from fast3r.dust3r.utils.image import load_images
from fast3r.dust3r.inference_multiview import inference
from fast3r.models.fast3r import Fast3R
from fast3r.models.multiview_dust3r_module import MultiViewDUSt3RLitModule

#%%
# --- Setup ---
# Load the model from Hugging Face
# Fast3R.from_pretrained accepts either an HF repo id or, as here, a local
# checkpoint directory that was downloaded ahead of time.
model = Fast3R.from_pretrained("/home/liucong/codes/pretrained/fast3r")  # If you have networking issues, try pre-download the HF checkpoint dir and change the path here to a local directory
# Prefer GPU when available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move all model parameters/buffers to the chosen device before inference.
model = model.to(device)
#%%
# Create a lightweight lightning module wrapper for the model.
# This provides functions to estimate camera poses, evaluate 3D reconstruction, etc.
# NOTE: the wrapper holds a reference to `model`; it is not a copy.
lit_module = MultiViewDUSt3RLitModule.load_for_inference(model)
#%%

#%%
# Set model to evaluation mode
# .eval() switches modules such as dropout/batch-norm to inference behavior;
# it does NOT disable gradient tracking (see the inference cell below).
model.eval()
lit_module.eval()
#%%
# --- Load Images ---
# Provide a list of image file paths. Images can come from different cameras and aspect ratios.
# NOTE(review): `outdir` is defined here but never used in this script chunk —
# presumably an output directory for a later cell; confirm or remove.
outdir = Path("/home/liucong/data/3d/fast3r/hypersim/ai_028_005_cam_00/000005_04")
# `filelist` is a directory path here; load_images presumably globs the images
# inside it and resizes the long side to `size` — TODO confirm against
# fast3r.dust3r.utils.image.load_images.
filelist = "/home/liucong/data/3d/eval_data/hypersim/ai_028_005_cam_00/000005_04"
images = load_images(filelist, size=512, verbose=True)
#%%
# --- Run Inference ---
# The inference function returns a dictionary with predictions and view information.
# NOTE(review): no torch.no_grad()/inference_mode() wrapper is visible here —
# verify that fast3r's `inference` disables autograd internally, otherwise
# activations are kept for backprop and memory use grows.
output_dict, profiling_info = inference(
    images,
    model,
    device,
    dtype=torch.float32,  # or use torch.bfloat16 if supported
    verbose=True,
    profiling=True,  # also returns timing/memory stats as `profiling_info`
)
#%%
# --- Estimate Camera Poses ---
# This step estimates the camera-to-world (c2w) poses for each view using PnP.
poses_c2w_batch, estimated_focals = MultiViewDUSt3RLitModule.estimate_camera_poses(
    output_dict['preds'],
    niter_PnP=100,  # number of PnP refinement iterations
    # Focal length is taken from the first view's global-head prediction and
    # shared across views — TODO confirm against the module's docstring.
    focal_length_estimation_method='first_view_from_global_head'
)
# poses_c2w_batch is a list; the first element contains the estimated poses for each view.
# (index 0 selects the single scene in this batch-of-one run)
camera_poses = poses_c2w_batch[0]
#%%
# Display the estimated camera-to-world matrix shape for every view.
for idx, c2w in enumerate(camera_poses):
    print(f"Camera Pose for view {idx}:")
    print(c2w.shape)  # each pose is a (4, 4) np.array: the camera-to-world transform
#%%
# --- Extract 3D Point Clouds for Each View ---
# Each element in output_dict['preds'] corresponds to a view's point map.
for view_idx, pred in enumerate(output_dict['preds']):
    # Move the per-pixel 3D points to CPU and drop the leading batch dimension.
    point_cloud = pred['pts3d_in_other_view'].cpu().numpy().squeeze(0)
    print(f"Point Cloud Shape for view {view_idx}: {point_cloud.shape}")  # shape: (Height, Width, 3) — squeeze(0) removed the batch dim, e.g. (368, 512, 3)
# %%
